{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7e53efda-8330-4f93-8b0a-a26d25c95ce1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c18328a4fa614a84b23facc7e233de76",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Anneme onu ne kadar sevdiğimi anlatan bir mektup yaz<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Sevgili Anne,\n",
      "\n",
      "Bugün sana, seni ne kadar çok sevdiğimi ifade etmek istiyorum. Senin için hissettiklerim kelimelerle ifade edilemeyecek kadar derin ve güçlü. Sen benim hayatımın ışığı, en büyük destekçim ve en sevdiğim insansın.\n",
      "\n",
      "Her gün seninle geçirdiğim her an için minnettarım. Senin sevgin, rehberliğin ve desteğin olmadan hayatım çok farklı olurdu. Bana verdiğin\n"
     ]
    }
   ],
   "source": [
    "# pip install transformers==4.41.1\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "# https://huggingface.co/CohereForAI/aya-23-8B/\n",
    "model_id = \"/home/lyz/hf-models/aya-23-8B/\" # 本地路径\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id)\n",
    "\n",
    "# Format message with the command-r-plus chat template\n",
    "messages = [{\"role\": \"user\", \"content\": \"Anneme onu ne kadar sevdiğimi anlatan bir mektup yaz\"}]\n",
    "input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\")\n",
    "## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Anneme onu ne kadar sevdiğimi anlatan bir mektup yaz<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n",
    "\n",
    "gen_tokens = model.generate(\n",
    "    input_ids, \n",
    "    max_new_tokens=100, \n",
    "    do_sample=True, \n",
    "    temperature=0.3,\n",
    "    )\n",
    "\n",
    "gen_text = tokenizer.decode(gen_tokens[0])\n",
    "print(gen_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d274a9-b0da-425c-82bb-bc93c0a8c254",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 3/1000 [00:30<2:58:12, 10.72s/it]"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "test_lines = open('testA.nl').readlines()\n",
    "\n",
    "result = []\n",
    "for line in tqdm(test_lines):\n",
    "    # Format message with the command-r-plus chat template\n",
    "    messages = [{\"role\": \"user\", \"content\": f\"将下面荷兰语翻译为中文：{line}\"}]\n",
    "    input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\")\n",
    "    \n",
    "    gen_tokens = model.generate(\n",
    "        input_ids, \n",
    "        max_new_tokens=100, \n",
    "        do_sample=True, \n",
    "        temperature=0.3,\n",
    "    )\n",
    "    \n",
    "    gen_text = tokenizer.decode(gen_tokens[0])\n",
    "    result.append(\n",
    "        gen_text.split('<|CHATBOT_TOKEN|>')[1].split('<|')[0]\n",
    "    )\n",
    "\n",
    "    with open('submit.csv', 'a') as up:\n",
    "        up.write(result[-1].replace('\\n', '') + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b375b3-8a5f-411d-b1e1-049e6350c5c5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
